#!/usr/bin/env python
# -*- coding: utf-8 -*-
#from __future__ import division, with_statement
'''
Copyright 2013, 陈同 (chentong_biology@163.com).  
===========================================================
'''
# Author metadata for this script.
__author__ = 'chentong & ct586[9]'
__author_email__ = 'chentong_biology@163.com'
#=========================================================
# Usage blurb; cmdparameter() prints it to stderr when the script is run
# without any command-line arguments.
desc = '''
Functional description:
    This is designed to integrate the CDS information predicted using
    transdecoder with transcripts assembled using cufflinks.
'''

import sys
import os
from json import dumps as json_dumps
from time import localtime, strftime 
# strftime pattern shared by the verbose message in main() and the run log
# written by the script entry point.
timeformat = "%Y-%m-%d %H:%M:%S"
from optparse import OptionParser as OP
#from bs4 import BeautifulSoup

#reload(sys)
#sys.setdefaultencoding('utf8')

#from multiprocessing.dummy import Pool as ThreadPool

# Module-wide debug flag; main() overwrites it with the value of -d/--debug.
debug = 0

def fprint(content):
    '''
    Dump *content* to standard output as JSON, indented by one space
    per nesting level.
    '''
    rendered = json_dumps(content, indent=1)
    print(rendered)

def cmdparameter(argv):
    '''
    Parse the command line of this script.

    argv : the full argument vector, program name first (same layout as
           sys.argv).

    Returns the (options, args) pair produced by optparse.

    Exits with status 1 after printing the module description (stderr)
    and the option help (stdout) when no argument is given.  Exits with
    status 2 through parser.error() when -i or -t is missing; both files
    are opened unconditionally by main(), so neither is optional.
    '''
    usages = "%prog -i file"
    # Pin the program name to argv[0] so the help text does not depend on
    # the process-wide sys.argv.
    if argv:
        parser = OP(usage=usages, prog=os.path.basename(argv[0]))
    else:
        parser = OP(usage=usages)
    parser.add_option("-i", "--input-file", dest="filein",
        metavar="FILEIN", help="Cuffcompare or cufflink GTF. "
        "Any type of GTF would suitable. The GTF should be sorted by "
        "transcript name.")
    parser.add_option("-t", "--translate-gtf", dest="trans",
        help="Normally GTFs generated by transdecoder. Each transcript "
        "will be treated as a scaffold for de-novo assembling.")
    parser.add_option("-c", "--correct", dest="correct",
        default=0, help="Correct the pep file, cds file (fasta format) generated "
        "by Transdecoder according to the strand of transcripts. "
        "Since come transcripts assembled using cufflinks have no direction "
        "information, we run Transdecoder in non-strand mode. "
        "Here we only select direct translation of transcripts "
        "which direction is known in cufflink GTFs. "
        "Default 0. Accept 1 meaning execute correction.")
    parser.add_option("-v", "--verbose", dest="verbose",
        default=0, help="Show process information")
    parser.add_option("-d", "--debug", dest="debug",
        default=False, help="Debug the program")

    if len(argv) == 1:
        # Print the help in-process instead of spawning
        # "python <script> -h" through a shell: no quoting problems with
        # the script path and no dependency on which "python" is on PATH.
        sys.stderr.write(desc + "\n")
        parser.print_help()
        sys.exit(1)

    (options, args) = parser.parse_args(argv[1:])
    # Explicit checks rather than assert, which disappears under "python -O".
    if options.filein is None:
        parser.error("A filename needed for -i")
    if options.trans is None:
        parser.error("A filename needed for -t")
    return (options, args)
#--------------------------------------------------------------------
#--------------------------------------------------------------------

def cds(posL, start, end, strand):
    '''
    Project a CDS given in transcript coordinates onto genomic segments.

    posL   : list of integer genomic positions, one per transcript base
             (generateCDS builds it from exons sorted by start).
    start  : 0-based index of the first CDS base, counted along the
             transcript in the direction given by *strand*.
    end    : index one past the last CDS base.
    strand : '-' means the transcript runs against posL, so positions
             are walked in reverse; any other value walks posL as given.

    Returns a list of [start, end] pairs (both as strings), one per run
    of consecutive positions.  For '-' each pair is swapped and the list
    reversed, i.e. the result is ordered like posL.

    Raises IndexError when start or end-1 lies outside posL.
    The caller's posL is never modified.
    '''
    # Work on a private copy; reversing in place would silently reorder
    # the caller's list.
    if strand == '-':
        ordered = posL[::-1]
    else:
        ordered = list(posL)

    previous = ordered[start]
    boundL = [previous]
    for idx in range(start + 1, end):
        position = ordered[idx]
        # A jump of more than one base closes the running segment and
        # opens a new one.
        if abs(position - previous) > 1:
            boundL.append(previous)
            boundL.append(position)
        previous = position
    # Close the last segment.  Using "previous" keeps this valid for a
    # one-base CDS, where the loop above never runs.
    boundL.append(previous)

    segL = [[str(boundL[i]), str(boundL[i+1])]
            for i in range(0, len(boundL), 2)]
    if strand == '-':
        # Segments were collected high-to-low: flip each pair and the
        # order of the pairs.
        segL = [[high_low[1], high_low[0]] for high_low in reversed(segL)]
    return segL
#----------END cds----------------------

def _parse_gtf_attributes(attr_text):
    '''
    Split a GTF attribute column ('key value; key value;') into a dict.

    Values keep their surrounding quotes.  Only the first blank of a
    field separates key from value, so values may contain spaces.
    Returns (attrD, closed) where closed tells whether the column ended
    with a semicolon.
    '''
    text = attr_text.rstrip()
    closed = text.endswith(';')
    if closed:
        text = text[:-1]
    attrD = {}
    for field in text.split('; '):
        key, _, value = field.strip().partition(' ')
        if key:
            attrD[key] = value
    return attrD, closed


def _join_gtf_attributes(attrD, middle, closed):
    '''
    Build an attribute column: gene_id, transcript_id, *middle* (skipped
    when empty), then oId and tss_id for those that exist in attrD.
    '''
    partL = ["gene_id %s" % attrD["gene_id"],
             "transcript_id %s" % attrD["transcript_id"]]
    if middle:
        partL.append(middle)
    for key in ("oId", "tss_id"):
        if key in attrD:
            partL.append("%s %s" % (key, attrD[key]))
    text = '; '.join(partL)
    if closed:
        text += ';'
    return text


def output(trannoL, cdsSeg, transcript_strand, translation_strand, count):
    '''
    Print the records of one transcript followed by its CDS records.

    trannoL = [line1, line2, line3]   tab-separated GTF lines
    cdsSeg  = [(start1, end1), (start2, end2)]   coordinates as strings
    transcript_strand  : strand stated in the GTF; when it is '.', column
                         7 of every printed line is replaced by
                         translation_strand.
    translation_strand : strand of the translation being written.
    count   : 1-based rank of this translation for the transcript.  From
              the second translation on, ':<count>' is appended to the
              transcript_id and the attribute column is rebuilt from
              gene_id, transcript_id, exon_number, oId and tss_id.

    CDS lines reuse the columns of the last line of trannoL with feature
    'CDS', the segment coordinates and a CDS_number attribute.

    exon_number, oId and tss_id are written only when the input line has
    them, so GTFs without these Cuffcompare attributes are accepted.
    The closing ';' of the rebuilt column follows the input column.
    Raises KeyError when gene_id or transcript_id is missing.
    '''
    if not trannoL:
        return
    attrD, closed = {}, False
    for line in trannoL:
        lineL = line.split("\t")
        attrD, closed = _parse_gtf_attributes(lineL[8])
        if transcript_strand == '.':
            lineL[6] = translation_strand
        if count > 1:
            transcript_id = attrD["transcript_id"]
            # Insert the suffix inside the closing quote when there is one.
            if transcript_id.endswith('"'):
                transcript_id = transcript_id[:-1] + ":%d\"" % count
            else:
                transcript_id = "%s:%d" % (transcript_id, count)
            attrD["transcript_id"] = transcript_id
            exon_part = ""
            if "exon_number" in attrD:
                exon_part = "exon_number %s" % attrD["exon_number"]
            lineL[8] = _join_gtf_attributes(attrD, exon_part, closed)
        print('\t'.join(lineL))
    #----------END output original-------------
    #----------Output CDS-------------
    # lineL/attrD still describe the last annotation line (including the
    # renamed transcript_id when count > 1); it is the CDS template.
    for cds_count, (start, end) in enumerate(cdsSeg, 1):
        lineL[2] = "CDS"
        lineL[3] = start
        lineL[4] = end
        lineL[8] = _join_gtf_attributes(
            attrD, "CDS_number \"%d\"" % cds_count, closed)
        print('\t'.join(lineL))
    #----------Output CDS-------------
#-----------END output---------------

def generateCDS(trannoL, transL, depletedD):
    '''
    Write one transcript together with the CDS records of its
    translations.

    trannoL = [line1, line2, line3]   GTF lines of one transcript, without
                                      trailing newline
    transL = [[(start_1, end_1), '+', name],
              [(start_2, end_2), '+', name]]
             start/end are 1-based positions on the transcript sequence.
    depletedD : dict updated in place; the name of every translation
                that is dropped because it lies on the reverse strand of
                a stranded transcript is stored with value 0.

    For a transcript with strand '+' or '-' only '+' translations are
    written.  For strand '.' both '+' and '-' translations are written
    and give the records their strand.  When nothing is written, the
    transcript lines are printed unchanged, one per line.

    Exits with status 1 on any other transcript strand.
    '''
    exonL = []
    strand = None
    for line in trannoL:
        lineL = line.split()
        if lineL[2] == 'exon':
            lineL[3] = int(lineL[3])
            lineL[4] = int(lineL[4])
            exonL.append(lineL)
            strand = lineL[6]
    #--------------------------------------
    if not exonL:
        # Without exon records no CDS can be placed; pass the transcript
        # through untouched instead of failing on an unset strand.
        print('\n'.join(trannoL))
        return
    exonL.sort(key=lambda x: x[3])
    # One genomic position per transcript base, in ascending order.
    posL = [j for i in exonL for j in range(i[3], i[4]+1)]
    count = 0
    for trans in transL:
        (cds_from, cds_to), trans_strand, name = trans
        start = cds_from - 1    # 1-based inclusive -> 0-based index
        end   = cds_to          # exclusive end index
        if strand in ['+', '-']:
            if trans_strand != '+':
                # Translation of the reverse complement of a stranded
                # transcript: remember it so main() can drop it.
                depletedD[name] = 0
                continue
            cds_strand = strand
        elif strand == '.':
            if trans_strand not in ['+', '-']:
                continue
            # The translation decides the strand of an unstranded
            # transcript.
            cds_strand = trans_strand
        else:
            sys.stderr.write("Unexpected strand %s\n" % trannoL[-1])
            sys.exit(1)
        #---------END strand ---------------
        cdsSeg = cds(posL[:], start, end, cds_strand)
        count += 1
        output(trannoL, cdsSeg, strand, cds_strand, count)
    #----------END for trans------------
    #------For GTF strand inconsistent with translation strand
    if count == 0:
        # The lines carry no newline, so they must be joined with one;
        # joining with '' would fuse all records into a single line.
        print('\n'.join(trannoL))
#---------------------------------------------

def _load_translations(gff3_path):
    '''
    Read the CDS records of a TransDecoder GFF3 file.

    Returns {transcript: [[(start, end), strand, name], ...]} where
    transcript is column 1 and name is the text after the second '=' of
    the attribute column.
    '''
    transD = {}
    with open(gff3_path) as gff3:
        for line in gff3:
            lineL = line.split()
            # Skip blank lines, comment/pragma lines and anything too
            # short to carry a feature type.
            if len(lineL) < 3 or lineL[0].startswith('#'):
                continue
            if lineL[2] != 'CDS':
                continue
            name = lineL[8].split('=')[2].strip()
            record = [(int(lineL[3]), int(lineL[4])), lineL[6], name]
            transD.setdefault(lineL[0], []).append(record)
    return transD


def _drop_depleted_records(fasta_path, depletedD):
    '''
    Rewrite a FASTA file without the records named in depletedD; the
    original is kept as <fasta_path>.bak.
    '''
    backup = fasta_path + '.bak'
    # os.rename raises when the file is missing, before anything is
    # truncated; a failed shell "mv" went unnoticed.
    os.rename(fasta_path, backup)
    with open(backup) as source:
        with open(fasta_path, 'w') as target:
            # Lines ahead of the first header are copied as they are.
            keep = True
            for line in source:
                if line.startswith('>'):
                    keep = line.split()[0][1:] not in depletedD
                if keep:
                    target.write(line)


def main():
    '''
    Merge TransDecoder CDS predictions into a GTF and print the result.

    Reads the GTF named by -i ('-' for stdin) transcript by transcript;
    the transcript name is taken from the second double-quoted value of
    each line, so all lines of a transcript must be adjacent.
    Transcripts with predictions in the -t file go through generateCDS(),
    all others are printed unchanged.  With -c set to a non-zero integer
    the .pep and .cds files next to the -t file (same path with 'gff3'
    replaced) are filtered to drop the translations rejected by
    generateCDS().
    '''
    options, args = cmdparameter(sys.argv)
    #-----------------------------------
    gtf_path = options.filein
    trans  = options.trans
    verbose = options.verbose
    correct = int(options.correct)
    global debug
    debug = options.debug
    #-----------------------------------
    if trans is None:
        sys.stderr.write("A filename needed for -t\n")
        sys.exit(1)
    transD = _load_translations(trans)
    depletedD = {}

    def flush(name, lineL):
        # Emit the collected lines of one transcript.
        if name in transD:
            generateCDS(lineL, transD[name], depletedD)
        else:
            print('\n'.join(lineL))

    if gtf_path == '-':
        fh = sys.stdin
    else:
        fh = open(gtf_path)
    try:
        tr = ""
        tmpL = []
        for line in fh:
            line = line.strip()
            if not line:    # blank lines carry no transcript name
                continue
            current = line.split('"')[3]
            if tr and current != tr:
                flush(tr, tmpL)
                tmpL = []
            tr = current
            tmpL.append(line)
        #--------The last --------------
        if tr:
            flush(tr, tmpL)
    finally:
        #----close file handle for files-----
        if gtf_path != '-':
            fh.close()
    #-------------END reading file----------
    if correct:
        # NOTE(review): str.replace changes every 'gff3' in the path, and
        # a path without 'gff3' would name the GFF3 file itself.
        for fasta in (trans.replace('gff3', 'pep'),
                      trans.replace('gff3', 'cds')):
            _drop_depleted_records(fasta, depletedD)
    ###--------multi-process------------------
    #pool = ThreadPool(5) # 5 represents thread_num
    #result = pool.map(func, iterable_object)
    #pool.close()
    #pool.join()
    ###--------multi-process------------------
    if verbose:
        sys.stderr.write(
            "--Successful %s\n" % strftime(timeformat, localtime()))

if __name__ == '__main__':
    # Stamp the wall-clock time on both sides of the run, then append
    # the command line and both stamps to python.log in the working
    # directory.
    began = strftime(timeformat, localtime())
    main()
    finished = strftime(timeformat, localtime())
    command_line = ' '.join(sys.argv)
    record = "%s\n\tRun time : %s - %s " % (command_line, began, finished)
    log = open('python.log', 'a')
    log.write(record + "\n")
    log.close()
    ###---------profile the program---------
    #import profile
    #profile_output = sys.argv[0]+".prof.txt")
    #profile.run("main()", profile_output)
    #import pstats
    #p = pstats.Stats(profile_output)
    #p.sort_stats("time").print_stats()
    ###---------profile the program---------


